function orbitPropagation()
    % ORBITPROPAGATION Compare orbit propagation for several operator orders.
    %
    %   For every entry of `orders`, propagates the initial state (r0, v0)
    %   over tspan with step dt using propagateOrbit. For each order it then:
    %     - prints the distance between the final and initial position;
    %     - plots the 3-D trajectory in a new figure.
    %
    %   The printed "Computation Error" is norm(r(:, end) - r(:, 1)), i.e. how
    %   far the last propagated position is from the starting position.
    %   NOTE(review): this only measures accuracy if tspan spans a whole number
    %   of orbital periods - confirm that this is the intended metric.

    % Constants
    mu = 398600;           % Gravitational parameter of Earth [km^3/s^2]
    r0 = [8000; 10000; 0]; % Initial position vector [km]
    v0 = [0; 7; 0];        % Initial velocity vector [km/s]
    tspan = [0 86400];     % Time span for propagation [seconds]
    dt = 60;               % Time step [seconds]

    % Orders of the state transition operator to compare
    orders = [1, 2, 3, 4, 5];

    for order = orders
        % Propagate orbit (sample times are not needed here)
        [~, r] = propagateOrbit(mu, r0, v0, tspan, dt, order);

        % Named closureError rather than `error` so MATLAB's built-in
        % error() function is not shadowed inside this function.
        closureError = norm(r(:, end) - r(:, 1));

        % Display results
        disp(['Order ', num2str(order), ' - Computation Error: ', num2str(closureError), ' km']);

        % Plot orbit
        figure;
        plot3(r(1, :), r(2, :), r(3, :));
        xlabel('X [km]');
        ylabel('Y [km]');
        zlabel('Z [km]');
        title(['Orbit Propagation - Order ', num2str(order)]);
        grid on;
        axis equal;
    end
end

function [t, r] = propagateOrbit(mu, r0, v0, tspan, dt, order)
    % PROPAGATEORBIT Advance [r; v] by repeatedly applying one fixed linear map.
    %
    %   [t, r] = propagateOrbit(mu, r0, v0, tspan, dt, order) builds
    %   T = stateTransitionTensor(mu, r0, dt, order) once. It then advances the
    %   stacked state x = [r0; v0] with x_{k+1} = T * x_k at every step.
    %
    %   Inputs:
    %     mu    - gravitational parameter, forwarded to stateTransitionTensor
    %     r0    - initial position (3 elements, row or column)
    %     v0    - initial velocity (3 elements, row or column)
    %     tspan - [t_start t_end]
    %     dt    - positive step size
    %     order - forwarded to stateTransitionTensor
    %   Outputs:
    %     t - 1-by-numSteps sample times, t(k) = tspan(1) + (k-1)*dt
    %     r - 3-by-numSteps positions, with r(:, 1) equal to r0
    %
    %   T is evaluated only at r0, so the same map is reused for every step.
    %   Raises 'propagateOrbit:badStep' for a non-positive dt. Raises
    %   'propagateOrbit:badOperator' if T cannot act on the state.

    if ~(isscalar(dt) && dt > 0)
        error('propagateOrbit:badStep', 'dt must be a positive scalar.');
    end

    % Column state regardless of whether r0 / v0 were given as rows.
    state = [r0(:); v0(:)];
    n = numel(state);

    T = stateTransitionTensor(mu, r0, dt, order);
    % Accept an operator delivered as a flattened (column-major) n*n array.
    if ~isequal(size(T), [n n]) && numel(T) == n * n
        T = reshape(T, n, n);
    end
    % The old block-wise loop silently zeroed the state when T was not
    % n-by-n (numBlocks < 1 meant the loop never ran); fail loudly instead.
    if ~isequal(size(T), [n n])
        error('propagateOrbit:badOperator', ...
            'Expected a %d-by-%d state transition operator, got size %s.', ...
            n, n, mat2str(size(T)));
    end

    numSteps = ceil((tspan(2) - tspan(1)) / dt) + 1;
    % Times computed directly (not accumulated) to avoid drift from repeated adds.
    t = tspan(1) + (0:numSteps - 1) * dt;
    r = zeros(3, numSteps);
    r(:, 1) = state(1:3);

    for k = 2:numSteps
        state = T * state;
        r(:, k) = state(1:3);
    end
end

function T = stateTransitionTensor(mu, r0, dt, order)
    % STATETRANSITIONTENSOR Order-N approximation of the one-step transition map.
    %
    %   T = stateTransitionTensor(mu, r0, dt, order) returns the square matrix
    %       T = sum_{k=0}^{order} (dt*J)^k / k!
    %   where I + dt*J is the first-order map from
    %   stateTransitionMatrix(mu, r0, dt). This is the Taylor series of
    %   expm(dt*J) truncated after `order` terms:
    %     - order = 1 reproduces stateTransitionMatrix;
    %     - higher orders converge toward expm(dt*J);
    %     - order = 0 gives the identity.
    %
    %   The result has the same size as the state transition matrix, so it can
    %   be applied directly to the state vector (T * x).
    %
    %   Raises 'stateTransitionTensor:badOrder' unless order is a nonnegative
    %   integer scalar.

    if ~(isscalar(order) && order >= 0 && order == fix(order))
        error('stateTransitionTensor:badOrder', ...
            'order must be a nonnegative integer scalar.');
    end

    % First-order map A = I + dt*J; recover M = dt*J from it.
    A = stateTransitionMatrix(mu, r0, dt);
    n = size(A, 1);
    M = A - eye(n);

    % Accumulate the truncated series; term holds M^k / k!.
    T = eye(n);
    term = eye(n);
    for k = 1:order
        term = term * M / k;
        T = T + term;
    end
end


function A = stateTransitionMatrix(mu, r0, dt)
    % STATETRANSITIONMATRIX First-order transition matrix of the two-body
    % dynamics linearized about position r0.
    %
    %   A = stateTransitionMatrix(mu, r0, dt) returns the 6-by-6 matrix
    %       A = I6 + dt * [ 0      I3 ;
    %                       G(r0)  0  ]
    %   where G(r0) = -mu/|r0|^3 * (I3 - 3*r0*r0'/|r0|^2). G is the gradient,
    %   with respect to r, of the point-mass acceleration -mu*r/|r|^3,
    %   evaluated at r0.
    %
    %   r0 may be a row or a column; it is used as a column so that r0*r0' is
    %   the 3-by-3 outer product. Raises 'stateTransitionMatrix:zeroRadius' for
    %   a zero position vector.

    % A row r0 would make r0*r0' a scalar and silently corrupt G.
    r0 = r0(:);
    r = norm(r0);
    if r == 0
        error('stateTransitionMatrix:zeroRadius', ...
            'r0 must be a nonzero position vector.');
    end

    I3 = eye(3);
    Z3 = zeros(3);

    % Gravity-gradient block of the Jacobian.
    G = -mu/r^3 * (I3 - 3*(r0*r0')/r^2);
    A = eye(6) + dt * [Z3, I3; G, Z3];
end

function C = tensorProduct(A, B, dim)
    % TENSORPRODUCT Outer product of the vectorized leading dim-by-dim blocks.
    %
    %   C = tensorProduct(A, B, dim) returns the dim^2-by-dim^2 matrix with
    %       C(p, q) = A(i, j) * B(k, l),  p = i + (j-1)*dim,  q = k + (l-1)*dim
    %   That is, the column-major vectorization of A(1:dim, 1:dim) times the
    %   (unconjugated) row vectorization of B(1:dim, 1:dim). Only the leading
    %   dim-by-dim block of each input is read. The result is stored as double.

    leadA = A(1:dim, 1:dim);
    leadB = B(1:dim, 1:dim);
    colA = reshape(leadA, [], 1);
    rowB = reshape(leadB, 1, []);

    % Element-wise expansion (not mtimes) so integer-class inputs still work.
    C = double(bsxfun(@times, colA, rowB));
end